{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 7 - Federated Learning with FederatedDataset\n",
    "\n",
    "Here we introduce a new tool for working with federated datasets: the `FederatedDataset` class. It is intended to be used like the PyTorch `Dataset` class, and is passed to a federated data loader, `FederatedDataLoader`, which iterates over it in a federated fashion.\n",
    "\n",
    "\n",
    "Authors:\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "- Théo Ryffel - GitHub: [@LaRiffle](https://github.com/LaRiffle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the sandbox introduced in the previous lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "import syft as sy\n",
    "sy.create_sandbox(globals(), verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we search the grid for the data and target tensors tagged `#boston`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boston_data = grid.search(\"#boston\", \"#data\")\n",
    "boston_target = grid.search(\"#boston\", \"#target\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a model. The optimizers are created in a later cell, one per worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features = boston_data['alice'][0].shape[1]\n",
    "n_targets = 1\n",
    "\n",
    "model = th.nn.Linear(n_features, n_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we wrap the fetched data in a `FederatedDataset` and print the workers that each hold part of the data. We also create one optimizer per worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each worker's data and target tensors in a BaseDataset\n",
    "datasets = []\n",
    "for worker in boston_data.keys():\n",
    "    dataset = sy.BaseDataset(boston_data[worker][0], boston_target[worker][0])\n",
    "    datasets.append(dataset)\n",
    "\n",
    "# Build the FederatedDataset object\n",
    "dataset = sy.FederatedDataset(datasets)\n",
    "print(dataset.workers)\n",
    "\n",
    "# Create one optimizer per worker; the training loop looks them up by worker id\n",
    "optimizers = {}\n",
    "for worker in dataset.workers:\n",
    "    optimizers[worker] = th.optim.Adam(params=model.parameters(), lr=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We pass it to a `FederatedDataLoader` and specify the batch size, whether to shuffle, and whether to drop the last incomplete batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = sy.FederatedDataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we iterate over epochs. Notice how similar this is to plain, local PyTorch training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, epochs + 1):\n",
    "    loss_accum = 0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        # Send the model to the worker that holds this batch\n",
    "        model.send(data.location)\n",
    "        \n",
    "        # Use the optimizer dedicated to that worker\n",
    "        optimizer = optimizers[data.location.id]\n",
    "        optimizer.zero_grad()\n",
    "        pred = model(data)\n",
    "        # Mean squared error between predictions and targets\n",
    "        loss = ((pred.view(-1) - target)**2).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # Bring the updated model and the loss back to the local worker\n",
    "        model.get()\n",
    "        loss = loss.get()\n",
    "        \n",
    "        loss_accum += float(loss)\n",
    "        \n",
    "        if batch_idx % 8 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tBatch loss: {:.6f}'.format(\n",
    "                epoch, batch_idx, len(train_loader),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "    print('Total loss', loss_accum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org).\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top-level tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one-off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
